-
-
Notifications
You must be signed in to change notification settings - Fork 8.8k
[enc] Add a cat accessor to the booster. #11568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a new categorical accessor (Cats()
) on the C++ booster interface, refactors the C API and Python bindings to expose category data through a shared helper, and adds end-to-end tests comparing DMatrix and Booster category retrieval.
- Added
Cats()
methods toLearner
andGBTree
to expose category containers. - Refactored C API
XGBDMatrixGetCategories
and addedXGBoosterGetCategories
with a sharedGetCategoriesImpl
. - Extracted common Python
get_categories
logic into_get_categories
and updated tests to validate Booster category access.
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
Show a summary per file
File | Description |
---|---|
src/learner.cc | Added Cats() override on LearnerConfiguration to call through to GBM. |
src/gbm/gbtree.h | Implemented Cats() in GBTree returning the model’s category container. |
src/gbm/gbm.cc | Updated copyright years and removed unused includes. |
src/c_api/c_api.cc | Moved category JSON logic into GetCategoriesImpl ; added XGBoosterGetCategories . |
python-package/xgboost/testing/ordinal.py | Added comp_booster tests to compare Booster vs. DMatrix categories. |
python-package/xgboost/core.py | Introduced _get_categories helper and refactored DMatrix/Booster methods. |
include/xgboost/learner.h | Declared forward CatContainer , updated GetNumFeature return type, added Cats() . |
include/xgboost/gbm.h | Forward-declared CatContainer and provided default Cats() implementation. |
Comments suppressed due to low confidence (3)
include/xgboost/learner.h:231
- [nitpick] Method
Cats()
deviates from the existingGet*
naming convention (GetNumFeature
,GetFeatureTypes
). Consider renaming toGetCategories()
orGetCatContainer()
for consistency.
[[nodiscard]] virtual CatContainer const* Cats() const = 0;
include/xgboost/gbm.h:165
- [nitpick] The default implementation of
Cats()
inGradientBooster
logs a fatal error. To match otherGet*
methods and improve clarity, consider renaming this method toGetCategories()
or explicitly indicate it's unsupported.
[[nodiscard]] virtual CatContainer const* Cats() const {
src/c_api/c_api.cc:708
- When
cats.Empty()
is true, the function sets*out = nullptr
but continues and then unconditionally overrides*out
. Add areturn
after setting*out = nullptr
or restructure to avoid the override.
if (cats.Empty()) {
cc @rongou |
Related: #11088